import os
import re
import math
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt

# =========================
# 0) Paths and global style (the only section you normally need to edit)
# =========================
INPUT_XLSX  = r"D:\economy\paper_1\data.xlsx"
OUTPUT_DIR  = r"D:\economy\paper_1\output"

# Also write per-step ("intermediate") single-panel PDFs plus their data CSVs?
SAVE_INTERMEDIATE = True

# Font sizes used by set_chinese_plot_style() and figure titles
FONT_BASE   = 16
FONT_TITLE  = 20
FONT_SUB    = 18
FONT_TICK   = 13

# Bootstrap replications for the threshold-regression p-value
# (more = more stable but slower; 300~800 suggested)
N_BOOT      = 500
RANDOM_SEED = 202501

# Key years shown in the box plot of Fig. 4-1(d)
KEY_YEARS   = [2000, 2007, 2014, 2020]

# The four study stages as (start year, end year)
STAGES = [(2000, 2005), (2005, 2010), (2010, 2015), (2015, 2020)]

# Core cities; similarity_to_core normalizes names such as "成都市" -> "成都"
CORE_CITIES = ["成都", "重庆"]

# =========================
# 0.1) Output directories (created when this module is loaded)
# =========================
INTERMEDIATE_DIR = os.path.join(OUTPUT_DIR, "_intermediate")
INTER_PDF_DIR    = os.path.join(INTERMEDIATE_DIR, "pdf")
INTER_CSV_DIR    = os.path.join(INTERMEDIATE_DIR, "csv")

for _path in (OUTPUT_DIR,) + (
    (INTERMEDIATE_DIR, INTER_PDF_DIR, INTER_CSV_DIR) if SAVE_INTERMEDIATE else ()
):
    os.makedirs(_path, exist_ok=True)
del _path

def _save_csv(df: pd.DataFrame, path: str, index: bool = False):
    os.makedirs(os.path.dirname(path), exist_ok=True)
    df.to_csv(path, index=index, encoding="utf-8-sig")
    return path

def _save_fig(fig, path: str, dpi: int = 300):
    """Save *fig* to *path* at *dpi*, always close it, and return *path*.

    The parent directory is created when *path* has one. The figure is closed
    even if saving fails, so figures are not accumulated in pyplot's registry.
    """
    parent = os.path.dirname(path)
    try:
        if parent:
            # os.makedirs("") raises FileNotFoundError for bare file names.
            os.makedirs(parent, exist_ok=True)
        fig.savefig(path, dpi=dpi)
    finally:
        plt.close(fig)
    return path

# =========================
# 1) 字体与绘图风格（中文为主）
# =========================
def set_chinese_plot_style():
    """Configure global matplotlib rcParams for Chinese-labelled figures.

    Fonts are embedded as TrueType (Type 42) in PDF/PS output. The ASCII hyphen
    is used for minus signs. Chinese-capable sans-serif fonts are preferred
    (Microsoft YaHei / SimHei, then Arial / DejaVu Sans). The module's FONT_*
    sizes are applied, and a light dashed grid is enabled on all axes.
    """
    style = {
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
        "axes.unicode_minus": False,
        # Windows CJK fonts first; Latin fallbacks if they are not installed.
        "font.sans-serif": ["Microsoft YaHei", "SimHei", "Arial", "DejaVu Sans"],
        "font.size": FONT_BASE,
        "axes.titlesize": FONT_SUB,
        "axes.labelsize": FONT_BASE,
        "xtick.labelsize": FONT_TICK,
        "ytick.labelsize": FONT_TICK,
        "legend.fontsize": FONT_TICK,
        "figure.titlesize": FONT_TITLE,
        "axes.grid": True,
        "grid.alpha": 0.25,
        "grid.linestyle": "--",
    }
    matplotlib.rcParams.update(style)

set_chinese_plot_style()

# =========================
# 2) 读数据：兼容“单Sheet长表” or “多Sheet年表”
#    ✅ FIX: 兼容 “年份在首行、表头在次行” 的sheet
# =========================
def _clean_colname(s: str) -> str:
    s = str(s).strip()
    s = re.sub(r"\s+", "", s)
    s = s.replace("（", "(").replace("）", ")")
    return s

def _norm_city_name(x):
    if pd.isna(x):
        return x
    s = str(x).strip()
    s = re.sub(r"\s+", "", s)
    s = s.replace("市", "")
    s = s.replace("地区", "")
    s = s.replace("自治州", "")
    return s

def _find_col(cols, patterns):
    """Return the column label matched by the highest-priority pattern, or None.

    *patterns* are regexes tried in order (earlier = higher priority). For
    each pattern, *cols* are scanned in order. Each label is cleaned with
    _clean_colname before a case-insensitive ``re.search``. The original
    (uncleaned) label is returned.
    """
    hits = (
        col
        for pat in patterns
        for col in cols
        if re.search(pat, _clean_colname(col), flags=re.I)
    )
    return next(hits, None)

def _parse_year_from_sheetname(sh):
    m = re.search(r"(19\d{2}|20\d{2})", str(sh))
    if m:
        y = int(m.group(1))
        if 1900 <= y <= 2100:
            return y
    return None

def _infer_year_from_preview(preview: pd.DataFrame):
    """
    在sheet前几行任意单元格里找 4位年份（1900-2100），返回出现次数最多的那个年份
    """
    vals = preview.to_numpy().ravel()
    yrs = []
    for v in vals:
        if pd.isna(v):
            continue
        iv = None
        if isinstance(v, (int, np.integer)):
            iv = int(v)
        elif isinstance(v, float) and float(v).is_integer():
            iv = int(v)
        else:
            s = str(v).strip()
            m = re.search(r"(19\d{2}|20\d{2})", s)
            if m:
                iv = int(m.group(1))
        if iv is not None and 1900 <= iv <= 2100:
            yrs.append(iv)
    if not yrs:
        return None
    uniq, cnt = np.unique(yrs, return_counts=True)
    return int(uniq[np.argmax(cnt)])

def _infer_header_row(preview: pd.DataFrame):
    """Guess which preview row holds the real table header.

    Each row's non-missing cells are joined, cleaned with _clean_colname and
    case-folded. The row is then scored by keyword hits: city +3, GDP +3,
    year +1, +1 per sector keyword found, and +1 for a share/ratio keyword.
    The highest-scoring row (the first one on ties) is returned if its score
    is >= 4; otherwise row 0 is returned.

    Matching is case-insensitive, so English headers such as "City", "Year" or
    "gdp" are recognised. This is consistent with _find_col's re.I matching.
    """
    city_kw = ["城市", "city", "地市", "地级市", "地区", "名称"]
    gdp_kw = ["GDP", "生产总值", "国内生产总值", "地区生产总值"]
    year_kw = ["年份", "年度", "year", "时间"]
    sector_kw = ["第一产业", "第二产业", "第三产业", "一产", "二产", "三产"]
    share_kw = ["比重", "占比", "比例"]

    def has_any(text, keywords):
        return any(k.casefold() in text for k in keywords)

    best_i, best_s = 0, -1
    for i in range(preview.shape[0]):
        cells = [str(v) for v in preview.iloc[i].tolist() if not pd.isna(v)]
        text = _clean_colname(" ".join(cells)).casefold()

        score = 0
        if has_any(text, city_kw):
            score += 3
        if has_any(text, gdp_kw):
            score += 3
        if has_any(text, year_kw):
            score += 1
        score += sum(1 for k in sector_kw if k in text)
        if has_any(text, share_kw):
            score += 1

        if score > best_s:
            best_i, best_s = i, score

    # A score of >= 4 is taken as a real header row; otherwise fall back to row 0.
    return best_i if best_s >= 4 else 0

def _read_sheet_smart(xlsx_path: str, sheet_name: str, preview_rows: int = 12):
    """Read one sheet, auto-detecting its year and header row.

    The first *preview_rows* rows are read without a header. The year is taken
    from the sheet name, or failing that from the preview cells. The header row
    is located with _infer_header_row. The sheet is then re-read with that
    header. All-empty rows/columns are dropped, labels are stripped strings,
    and "Unnamed..." columns are removed.

    Returns (df, year_hint); year_hint is None when no year was found.
    """
    head = pd.read_excel(
        xlsx_path, sheet_name=sheet_name,
        header=None, nrows=preview_rows, engine="openpyxl",
    )

    year_hint = _parse_year_from_sheetname(sheet_name)
    if year_hint is None:
        year_hint = _infer_year_from_preview(head)

    table = pd.read_excel(
        xlsx_path, sheet_name=sheet_name,
        header=_infer_header_row(head), engine="openpyxl",
    )
    table = table.dropna(axis=1, how="all").dropna(axis=0, how="all")
    table.columns = [str(col).strip() for col in table.columns]

    # Drop "Unnamed: N" placeholder labels (blank header cells).
    named = [re.match(r"^Unnamed", str(col), flags=re.I) is None for col in table.columns]
    return table.loc[:, named], year_hint

_CITY_COL_PATTERNS = [r"^city$", r"城市", r"地区", r"地市", r"地级市", r"名称"]
_YEAR_COL_PATTERNS = [r"^year$", r"年份", r"年度", r"时间"]
_GDP_COL_PATTERNS = [r"^gdp$", r"国内生产总值", r"地区生产总值", r"生产总值", r"GDP"]
_SHARE_COL_PATTERNS = {
    "s1_raw": [r"第一产业.*(比重|占比|比例)", r"一产.*(比重|占比|比例)", r"第一产业"],
    "s2_raw": [r"第二产业.*(比重|占比|比例)", r"二产.*(比重|占比|比例)", r"第二产业"],
    "s3_raw": [r"第三产业.*(比重|占比|比例)", r"三产.*(比重|占比|比例)", r"第三产业"],
}
_RAW_VALUE_COLS = ["gdp_raw", "s1_raw", "s2_raw", "s3_raw"]


def _standardize_sheet(df, col_city, col_year, year_value):
    """Project one sheet onto the canonical columns city/year/gdp_raw/s1_raw/s2_raw/s3_raw.

    GDP and sector columns are located on THIS sheet's own header. Sheets whose
    header texts differ (e.g. units inside the GDP column name) still line up
    after concatenation. Value columns missing from the sheet become NaN.
    The year comes from *col_year* when given, else the scalar *year_value*.

    Returns (frame, found), where *found* is the set of value columns located.
    """
    labels = list(df.columns)

    def pick(label):
        # Positional access also works when labels are duplicated.
        return df.iloc[:, labels.index(label)]

    sources = {"gdp_raw": _find_col(labels, _GDP_COL_PATTERNS)}
    for key, pats in _SHARE_COL_PATTERNS.items():
        sources[key] = _find_col(labels, pats)

    out = pd.DataFrame(index=df.index)
    out["city"] = pick(col_city).apply(_norm_city_name)
    out["year"] = pick(col_year) if col_year is not None else year_value
    for key, label in sources.items():
        out[key] = pick(label) if label is not None else np.nan
    found = {key for key, label in sources.items() if label is not None}
    return out, found


def _shares_to_percent(frame):
    """Add p1/p2/p3 (percent, 0-100) to a cleaned sheet frame.

    The unit is detected from the median row sum of s1_raw+s2_raw+s3_raw.
    A sum of about 100 is taken as percentages and about 1 as fractions (x100).
    Anything else is treated as value added and converted to structural
    shares (s_i / sum * 100).
    """
    s1, s2, s3 = frame["s1_raw"], frame["s2_raw"], frame["s3_raw"]
    ssum = s1 + s2 + s3
    med_sum = float(np.nanmedian(ssum.to_numpy())) if len(ssum) else float("nan")
    if 80 < med_sum < 120:
        p1, p2, p3 = s1, s2, s3
    elif 0.8 < med_sum < 1.2:
        p1, p2, p3 = s1 * 100.0, s2 * 100.0, s3 * 100.0
    else:
        p1, p2, p3 = s1 / ssum * 100.0, s2 / ssum * 100.0, s3 / ssum * 100.0
    return frame.assign(p1=p1, p2=p2, p3=p3)


def load_panel_any_excel(xlsx_path: str) -> pd.DataFrame:
    """Load the workbook into a standard panel: city, year, gdp, p1, p2, p3 (p in percent, 0-100).

    Two sheet layouts are accepted:
      A) "long" sheets with both a city column and a year column;
      B) single-year sheets with a city column, whose year is taken from the
         sheet name or from a cell near the top (see _read_sheet_smart).
    Other sheets are skipped. Columns are resolved and share units are
    normalized per sheet, so sheets with differently worded headers or
    different share units are combined correctly. Rows missing any value
    are dropped. Duplicate (city, year) rows keep the first occurrence in
    sheet order. The result is sorted by city and year.

    Raises:
        ValueError: no sheet has a recognisable layout.
        KeyError: no sheet has a GDP column, or the three sector columns
            are not all found.
    """
    # Context manager closes the workbook handle once the sheet names are read.
    with pd.ExcelFile(xlsx_path, engine="openpyxl") as xls:
        sheets = xls.sheet_names

    frames = []
    found_any = set()
    debug = []

    for sh in sheets:
        try:
            df, year_hint = _read_sheet_smart(xlsx_path, sh)
        except Exception as e:
            # Best-effort: an unreadable sheet is recorded and skipped, not fatal.
            debug.append((sh, "read_fail", str(e)))
            continue

        if df is None or df.shape[0] == 0 or df.shape[1] == 0:
            debug.append((sh, "empty", ""))
            continue

        cols = list(df.columns)
        col_city = _find_col(cols, _CITY_COL_PATTERNS)
        col_year = _find_col(cols, _YEAR_COL_PATTERNS)

        if col_city is None or (col_year is None and year_hint is None):
            debug.append((sh, "skip", f"city={col_city} yearcol={col_year} yearhint={year_hint}"))
            continue

        if col_year is not None:
            # Case A: long sheet
            debug.append((sh, "long", ""))
            year_value = None
        else:
            # Case B: single-year sheet
            debug.append((sh, "yearsheet", f"year={year_hint}"))
            year_value = int(year_hint)

        std, found = _standardize_sheet(df, col_city, col_year, year_value)
        frames.append(std)
        found_any |= found

    if len(frames) == 0:
        msg = "Excel中无法识别有效数据：既未找到长表sheet，也未找到可解析年份的多sheet。"
        msg += "\n你可以检查：是否存在“城市”列；若sheet名不是年份，sheet内是否有一个单元格写年份；表头是否在第1行而非第0行。"
        msg += "\n调试信息(前10个sheet):\n" + "\n".join([str(x) for x in debug[:10]])
        raise ValueError(msg)

    if "gdp_raw" not in found_any:
        raise KeyError("未找到GDP列（可能列名不是GDP/国内生产总值/地区生产总值）。请检查Excel表头。")
    if not {"s1_raw", "s2_raw", "s3_raw"} <= found_any:
        raise KeyError("未找到完整的一/二/三产业列（至少需要三列）。请检查列名是否包含：第一产业/第二产业/第三产业。")

    out_cols = ["city", "year", "gdp", "p1", "p2", "p3"]
    cleaned = []
    for std in frames:
        for c in _RAW_VALUE_COLS + ["year"]:
            std[c] = pd.to_numeric(std[c], errors="coerce")
        std = std.dropna(subset=["city", "year"] + _RAW_VALUE_COLS)
        if not std.empty:
            # Unit detection per sheet: different sheets may store shares differently.
            cleaned.append(_shares_to_percent(std))

    if not cleaned:
        return pd.DataFrame(columns=out_cols)

    use = pd.concat(cleaned, ignore_index=True)
    use["year"] = use["year"].astype(int)
    use = use.rename(columns={"gdp_raw": "gdp"})[out_cols]

    use = use.drop_duplicates(subset=["city", "year"])
    return use.sort_values(["city", "year"]).reset_index(drop=True)

# =========================
# 3) 相似性、增长、门槛回归
# =========================
def cosine_similarity(a, b):
    """Cosine similarity of two vectors as a float; NaN if either has zero norm."""
    u = np.asarray(a, dtype=float)
    v = np.asarray(b, dtype=float)
    norm_u = np.linalg.norm(u)
    norm_v = np.linalg.norm(v)
    if 0 in (norm_u, norm_v):
        return np.nan
    return float(np.dot(u, v) / (norm_u * norm_v))

def build_similarity_matrix(df_year):
    """Pairwise cosine similarity of the cities' (p1, p2, p3) share vectors.

    df_year: rows for a single year with columns city, p1, p2, p3 (percent).
    Returns (cities, sim): the city list in row order and an n x n float
    matrix with sim[i, j] = cosine_similarity(city i, city j).
    """
    cities = list(df_year["city"].values)
    shares = df_year[["p1", "p2", "p3"]].values / 100.0
    size = len(cities)
    sim = np.full((size, size), np.nan, dtype=float)
    for row, vec_i in enumerate(shares):
        sim[row] = [cosine_similarity(vec_i, vec_j) for vec_j in shares]
    return cities, sim

def annual_log_growth(g0, g1, years):
    """Average annual log growth in percent: (ln g1 - ln g0) / years * 100.

    Returns NaN if either level or the year span is non-positive.
    """
    if any(v <= 0 for v in (g0, g1, years)):
        return np.nan
    log_change = math.log(g1) - math.log(g0)
    return log_change / years * 100.0

def stage_panel(panel, y0, y1):
    """Build the city-level cross-section for the stage y0 -> y1.

    Returns one row per city present in both years with columns:
      city, gdp0, p1_0, p2_0, p3_0 (start year), gdp1, p1_1, p2_1, p3_1 (end year),
      growth     = annual_log_growth(gdp0, gdp1, y1 - y0)  (percent per year),
      upgrade_p3 = p3_1 - p3_0  (change of tertiary share, percentage points).
    Rows whose growth is NaN (non-positive GDP or span) are dropped.
    """
    cols = ["city", "gdp", "p1", "p2", "p3"]
    start = panel.loc[panel["year"] == y0, cols].rename(
        columns={"gdp": "gdp0", "p1": "p1_0", "p2": "p2_0", "p3": "p3_0"}
    )
    end = panel.loc[panel["year"] == y1, cols].rename(
        columns={"gdp": "gdp1", "p1": "p1_1", "p2": "p2_1", "p3": "p3_1"}
    )
    m = start.merge(end, on="city", how="inner")

    span = y1 - y0
    # Explicit float Series instead of DataFrame.apply(axis=1). The result stays a
    # 1-D column even for an empty merge, where apply's result type depends on
    # pandas' empty-input inference. It also avoids building one Series per row.
    m["growth"] = pd.Series(
        [annual_log_growth(g0, g1, span) for g0, g1 in zip(m["gdp0"], m["gdp1"])],
        index=m.index,
        dtype=float,
    )
    m["upgrade_p3"] = m["p3_1"] - m["p3_0"]
    return m.dropna(subset=["growth"])

def similarity_to_core(panel, year, core_city):
    """Cosine similarity of every city's sector structure to *core_city* in *year*.

    City names (and *core_city*) are normalized with _norm_city_name. Returns
    None if the core city has no complete row in that year. Otherwise it
    returns a Series indexed by city, named ``sim_<normalized core>``.
    """
    frame = panel.loc[panel["year"] == year, ["city", "p1", "p2", "p3"]].dropna().copy()
    frame["city"] = frame["city"].apply(_norm_city_name)
    core = _norm_city_name(core_city)

    if core not in set(frame["city"]):
        return None

    core_vec = frame.loc[frame["city"] == core, ["p1", "p2", "p3"]].values[0] / 100.0
    vectors = frame[["p1", "p2", "p3"]].to_numpy(dtype=float) / 100.0
    sims = {
        city: cosine_similarity(vec, core_vec)
        for city, vec in zip(frame["city"], vectors)
    }
    return pd.Series(sims, name=f"sim_{core}")

def ols_fit(X, y):
    """Ordinary least squares of y on the columns of X (no intercept added).

    X: (n, k) design matrix, y: (n,) outcome.
    Returns (beta, yhat, resid, sse, r2, se):
      beta (k,) least-squares coefficients, yhat (n,) fitted values,
      resid (n,) residuals, sse residual sum of squares (float),
      r2 = 1 - sse/sst (NaN when sst == 0), se (k,) classical
      (homoskedastic) standard errors with dof = max(n - k, 1).
    """
    design = np.asarray(X, float)
    target = np.asarray(y, float)
    n_obs, n_par = design.shape

    coef = np.linalg.lstsq(design, target, rcond=None)[0]
    fitted = design @ coef
    resid = target - fitted
    sse = float(np.sum(resid ** 2))
    sst = float(np.sum((target - target.mean()) ** 2))
    r2 = 1 - sse / sst if sst != 0 else np.nan

    sigma2 = sse / max(n_obs - n_par, 1)
    cov = sigma2 * np.linalg.pinv(design.T @ design)
    se = np.sqrt(np.diag(cov))
    return coef, fitted, resid, sse, r2, se

def threshold_regression(y, x, q=None, trim=0.15, n_boot=500, seed=0):
    """Single-threshold regression with regime-specific intercept and slope.

        y = a1 + b1*x   if q <= gamma   (low regime)
        y = a2 + b2*x   if q >  gamma   (high regime)

    q defaults to x (self-threshold). gamma is chosen by grid search over the
    distinct values of q inside the trimmed range [trim, 1 - trim] of the
    sorted sample, minimising the SSE. The linear no-threshold model is the
    null. Its residuals are resampled n_boot times (seeded by *seed*) to give
    a bootstrap p-value for the LR-type statistic (SSE0 - SSE1) / sigma1^2.

    Returns None if fewer than 10 finite observations remain or fewer than 5
    candidate thresholds exist. Otherwise it returns a dict with keys:
      gamma, a1, a2, b1, b2, se_a1, se_a2, se_b1, se_b2, sse, r2,
      yhat (threshold-model fitted values), lr_stat, p_boot (NaN if n_boot == 0),
      n, sse0 and r2_0 (linear model).
    """
    rng = np.random.default_rng(seed)
    y = np.asarray(y, float)
    x = np.asarray(x, float)
    q = x.copy() if q is None else np.asarray(q, float)

    # Keep observations where outcome, regressor and threshold variable are all finite.
    mask = np.isfinite(y) & np.isfinite(x) & np.isfinite(q)
    y, x, q = y[mask], x[mask], q[mask]
    n = len(y)
    if n < 10:
        return None

    # Candidate thresholds: trimmed central range so both regimes keep observations.
    qs = np.sort(q)
    lo = int(math.floor(trim * n))
    hi = int(math.ceil((1 - trim) * n)) - 1
    cand = np.unique(qs[lo:hi + 1])
    if len(cand) < 5:
        return None

    k1 = 4  # number of parameters in the threshold model

    def design(gamma):
        # Columns [1, D, x*(1-D), x*D] -> coefficients a1, (a2 - a1), b1, b2; D = 1 if q > gamma.
        D = (q > gamma).astype(float)
        return np.column_stack([np.ones_like(x), D, x * (1 - D), x * D])

    def sse_of(X, yvec):
        # Same SSE as ols_fit, without its standard-error work (not needed in the search).
        beta, *_ = np.linalg.lstsq(X, yvec, rcond=None)
        resid = yvec - X @ beta
        return float(np.sum(resid ** 2))

    def search(yvec):
        # Grid search over cand; strict '<' keeps the first minimiser on ties.
        best_g, best_sse = None, None
        for g in cand:
            s = sse_of(design(g), yvec)
            if best_sse is None or s < best_sse:
                best_g, best_sse = g, s
        return best_g, best_sse

    # 1) estimate gamma, then one full fit (with standard errors) at that gamma
    gamma, _ = search(y)
    X1 = design(gamma)
    beta, yhat, _resid1, sse1, r2, se = ols_fit(X1, y)
    sigma2 = sse1 / max(n - k1, 1)

    # a2 = beta0 + beta1, so Var(a2) = Var(b0) + Var(b1) + 2*Cov(b0, b1).
    # For this design Cov(b0, b1) = -Var(a1); dropping it overstates se(a2).
    cov = sigma2 * np.linalg.pinv(X1.T @ X1)
    var_a2 = cov[0, 0] + cov[1, 1] + 2.0 * cov[0, 1]
    se_a2 = float(np.sqrt(max(var_a2, 0.0)))

    best = {
        "gamma": gamma, "a1": beta[0], "a2": beta[0] + beta[1], "b1": beta[2], "b2": beta[3],
        "se_a1": se[0], "se_a2": se_a2, "se_b1": se[2], "se_b2": se[3],
        "sse": sse1, "r2": r2, "yhat": yhat,
    }

    # 2) linear (no-threshold) null model
    X0 = np.column_stack([np.ones_like(x), x])
    _beta0, yhat0, resid0, sse0, r2_0, _se0 = ols_fit(X0, y)

    # 3) LR-type statistic (Hansen-style): (SSE0 - SSE1) / sigma1^2
    lr_obs = (sse0 - sse1) / max(sigma2, 1e-12)

    # 4) residual bootstrap under the null; gamma is re-searched in every replicate
    lr_boot = []
    for _ in range(int(n_boot)):
        y_star = yhat0 + rng.choice(resid0, size=n, replace=True)
        _g_star, sse1_star = search(y_star)
        sse0_star = sse_of(X0, y_star)
        sigma2_star = sse1_star / max(n - k1, 1)
        lr_boot.append((sse0_star - sse1_star) / max(sigma2_star, 1e-12))

    lr_boot = np.asarray(lr_boot, float)
    # With no replicates there is no p-value (avoids np.mean of an empty array).
    p_boot = float(np.mean(lr_boot >= lr_obs)) if lr_boot.size else float("nan")

    best["lr_stat"] = float(lr_obs)
    best["p_boot"] = p_boot
    best["n"] = int(n)
    best["sse0"] = float(sse0)
    best["r2_0"] = float(r2_0)
    return best

def _predict_threshold(res, x):
    x = np.asarray(x, float)
    g = res["gamma"]
    return np.where(x <= g, res["a1"] + res["b1"] * x, res["a2"] + res["b2"] * x)

def _draw_threshold_panel(ax, data: pd.DataFrame, xcol: str, ycol: str, res: dict,
                          xlabel: str, ylabel: str, title: str):
    """Draw one threshold-regression panel on *ax*.

    Shared by the combined figures and the per-step single figures. It draws
    a scatter of data[xcol] vs data[ycol], a dashed vertical line at gamma, the
    fitted piecewise line over 200 points, and a box with gamma, b1/b2, R^2
    and the bootstrap p-value. If *res* or *data* is missing or empty, a
    "sample too small" note is shown and the axis is hidden.
    """
    ax.set_title(title)
    if res is None or data is None or len(data) == 0:
        ax.text(0.5, 0.5, "样本不足，无法估计门槛模型", ha="center", va="center")
        ax.set_axis_off()
        return

    xv = data[xcol].values
    yv = data[ycol].values
    gamma = res["gamma"]

    ax.scatter(xv, yv, s=45, alpha=0.85)
    ax.axvline(gamma, linestyle="--", linewidth=1.8)

    grid = np.linspace(np.nanmin(xv), np.nanmax(xv), 200)
    ax.plot(grid, _predict_threshold(res, grid), linewidth=2.2)

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    summary = "\n".join([
        f"γ={gamma:.3f}",
        f"b1={res['b1']:.3f}, b2={res['b2']:.3f}",
        f"R²={res['r2']:.3f}",
        f"p={res['p_boot']:.3f}",
    ])
    ax.text(0.02, 0.98, summary, transform=ax.transAxes,
            ha="left", va="top",
            bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="0.5", alpha=0.85))

# =========================
# 4) 图4-1：增长与结构演变（4子图）
# =========================
def fig4_1(panel: pd.DataFrame):
    """Draw Fig. 4-1 (2x2 panels) and export its underlying data to OUTPUT_DIR.

    (a) GDP over time for each city in CORE_CITIES;
    (b) total GDP of all cities per year (bars) and the city mean (line, right axis);
    (c) mean sector shares p1/p2/p3 per year (stacked area);
    (d) box plots of the tertiary share p3 in the KEY_YEARS that have data.

    Writes the PDF, a per-year CSV and a long CSV of the box-plot data.
    """
    out_pdf = os.path.join(OUTPUT_DIR, "图4-1_成渝城市群经济增长与产业结构演变.pdf")
    out_csv = os.path.join(OUTPUT_DIR, "图4-1_数据.csv")
    out_csv_box = os.path.join(OUTPUT_DIR, "图4-1_d箱线图数据_long.csv")

    years = np.sort(panel["year"].unique())

    # a) core-city GDP aligned to the full year axis (missing years -> NaN)
    def get_city_series(city):
        d = panel[panel["city"] == city].set_index("year").sort_index()
        return d.reindex(years)["gdp"]

    # Normalize configured names ("成都市" -> "成都") the same way the panel's cities are.
    core_names = [_norm_city_name(c) for c in CORE_CITIES]
    core_gdp = {name: get_city_series(name) for name in core_names}

    # b) cluster GDP total (bars) + mean (line)
    gdp_year = panel.groupby("year")["gdp"].agg(["sum", "mean", "median", "count"]).reindex(years)

    # c) mean sector structure (stacked area)
    share_year = panel.groupby("year")[["p1", "p2", "p3"]].mean().reindex(years)

    # d) tertiary-share distribution in the key years (box plot)
    box_data = []
    box_labels = []
    for y in KEY_YEARS:
        dy = panel[panel["year"] == y]["p3"].dropna().values
        if len(dy) > 0:
            box_data.append(dy)
            box_labels.append(str(y))

    # CSV export (for cross-checking numbers cited in the text/appendix)
    out = pd.DataFrame({
        "year": years,
        **{f"{name}_gdp": series.values for name, series in core_gdp.items()},
        "城市群_gdp总量": gdp_year["sum"].values,
        "城市群_gdp均值": gdp_year["mean"].values,
        "城市群_gdp中位数": gdp_year["median"].values,
        "样本城市数": gdp_year["count"].values,
        "平均_第一产业比重": share_year["p1"].values,
        "平均_第二产业比重": share_year["p2"].values,
        "平均_第三产业比重": share_year["p3"].values,
    })
    _save_csv(out, out_csv, index=False)

    # city-year-p3 detail behind panel (d)
    box_long = panel.loc[panel["year"].isin(KEY_YEARS), ["city", "year", "p3"]].copy()
    box_long = box_long.dropna().sort_values(["year", "p3"], ascending=[True, True]).reset_index(drop=True)
    _save_csv(box_long, out_csv_box, index=False)

    fig, axes = plt.subplots(2, 2, figsize=(16, 9), constrained_layout=True)
    fig.suptitle("图4-1  成渝城市群经济增长与产业结构演变", y=0.98, fontsize=FONT_TITLE)

    # (a)
    ax = axes[0, 0]
    line_styles = [{"marker": "o"}, {"marker": "s", "linestyle": "--"}]
    for k, (name, series) in enumerate(core_gdp.items()):
        ax.plot(years, series.values, linewidth=2, label=name, **line_styles[k % len(line_styles)])
    ax.set_title("(a) 核心城市GDP演变")
    ax.set_xlabel("年份")
    ax.set_ylabel("GDP（同原始单位）")
    ax.legend(loc="upper left", frameon=False)

    # (b) bars + line on twin y-axes
    ax = axes[0, 1]
    ax2 = ax.twinx()
    # axes.grid is enabled globally in rcParams; without this the twin axis would
    # draw a second grid that does not line up with the left axis' grid.
    ax2.grid(False)
    ax.bar(years, gdp_year["sum"].values, alpha=0.30, label="城市群总量")
    ax2.plot(years, gdp_year["mean"].values, marker="s", linewidth=2, linestyle="--", label="城市均值")
    ax.set_title("(b) 成渝城市群GDP总量与均值")
    ax.set_xlabel("年份")
    ax.set_ylabel("GDP总量（同原始单位）")
    ax2.set_ylabel("GDP均值（同原始单位）")

    h1, l1 = ax.get_legend_handles_labels()
    h2, l2 = ax2.get_legend_handles_labels()
    ax.legend(h1 + h2, l1 + l2, loc="upper left", frameon=False)

    # (c)
    ax = axes[1, 0]
    ax.stackplot(years,
                 share_year["p1"].values,
                 share_year["p2"].values,
                 share_year["p3"].values,
                 labels=["第一产业", "第二产业", "第三产业"],
                 alpha=0.80)
    ax.set_title("(c) 城市群平均产业结构演变")
    ax.set_xlabel("年份")
    ax.set_ylabel("比重（%）")
    ax.set_ylim(0, 100)
    ax.legend(loc="upper left", frameon=False)

    # (d)
    ax = axes[1, 1]
    if len(box_data) > 0:
        ax.boxplot(box_data, patch_artist=True)
        # Boxes sit at x = 1..n. Label them via ticks rather than boxplot(labels=...),
        # which Matplotlib 3.9 renamed to tick_labels and deprecated.
        ax.set_xticks(range(1, len(box_labels) + 1))
        ax.set_xticklabels(box_labels)
        ax.set_title("(d) 第三产业比重的城市分布（关键年份）")
        ax.set_xlabel("年份")
        ax.set_ylabel("第三产业比重（%）")
    else:
        ax.text(0.5, 0.5, "关键年份箱线图数据不足", ha="center", va="center")
        ax.set_axis_off()

    _save_fig(fig, out_pdf, dpi=300)

# =========================
# 5) 图4-2：相似性矩阵（2000/2010/2020）+ 不遮挡色标
#     ✅ 增加：每个年份单独输出 1 张 PDF + 1 份矩阵 CSV（逐次完成）
# =========================
def fig4_2(panel: pd.DataFrame):
    """Draw Fig. 4-2: pairwise sector-structure similarity matrices for 2000/2010/2020.

    For each selected year, the cosine similarity between every pair of
    cities' (p1, p2, p3) vectors is computed with build_similarity_matrix.
    Writes a long CSV of all pairs and a 1x3 heat-map PDF. The heat maps share
    one colour scale set by the 5%-95% quantiles of all finite values, with
    the colour bar in its own top row. With SAVE_INTERMEDIATE, each year also
    gets a wide matrix CSV and, if it has data, its own PDF.

    A year with no complete city rows is shown as a text placeholder instead
    of an empty 0x0 heat map.
    """
    out_pdf = os.path.join(OUTPUT_DIR, "图4-2_城市间产业结构相似性矩阵_2000_2010_2020.pdf")
    out_csv = os.path.join(OUTPUT_DIR, "图4-2_相似性矩阵_long.csv")

    years_sel = [2000, 2010, 2020]
    mats = []
    long_rows = []

    for y in years_sel:
        dfy = panel[panel["year"] == y][["city", "p1", "p2", "p3"]].dropna().copy()
        cities, sim = build_similarity_matrix(dfy)
        mats.append((y, cities, sim))

        # long format: one row per (city_i, city_j) pair
        for i, ci in enumerate(cities):
            for j, cj in enumerate(cities):
                long_rows.append({"year": y, "city_i": ci, "city_j": cj, "sim": sim[i, j]})

        if SAVE_INTERMEDIATE:
            # wide matrix CSV (rows/columns = cities) for manual checking
            wide = pd.DataFrame(sim, index=cities, columns=cities)
            csv_y = os.path.join(INTER_CSV_DIR, f"图4-2_{y}年相似性矩阵.csv")
            _save_csv(wide, csv_y, index=True)

            # single-year heat map, only when there is something to draw
            if len(cities) > 0:
                fig = plt.figure(figsize=(9, 8), constrained_layout=True)
                ax = fig.add_subplot(111)
                im = ax.imshow(sim, cmap="viridis", aspect="equal")
                ax.set_title(f"图4-2（逐次） {y}年 城市间产业结构相似性矩阵")
                ax.set_xticks(range(len(cities)))
                ax.set_yticks(range(len(cities)))
                ax.set_xticklabels(cities, rotation=90)
                ax.set_yticklabels(cities)
                ax.grid(False)
                cb = fig.colorbar(im, ax=ax, shrink=0.85, pad=0.02)
                cb.set_label("产业结构相似度（余弦相似）")

                pdf_y = os.path.join(INTER_PDF_DIR, f"图4-2_{y}年相似性矩阵.pdf")
                _save_fig(fig, pdf_y, dpi=300)

    _save_csv(pd.DataFrame(long_rows), out_csv, index=False)

    # Shared colour range from robust quantiles of all finite similarities.
    all_vals = np.concatenate([m[2].ravel() for m in mats])
    all_vals = all_vals[np.isfinite(all_vals)]
    vmin = float(np.quantile(all_vals, 0.05)) if len(all_vals) else 0.80
    vmax = float(np.quantile(all_vals, 0.95)) if len(all_vals) else 1.00

    # GridSpec reserves a full top row for the colour bar, so it cannot overlap
    # titles or heat maps under constrained_layout (unlike fig.add_axes).
    fig = plt.figure(figsize=(16, 6), constrained_layout=True)
    gs = fig.add_gridspec(nrows=2, ncols=3, height_ratios=[0.14, 0.86], hspace=0.05, wspace=0.05)
    cax = fig.add_subplot(gs[0, :])
    axes = [fig.add_subplot(gs[1, i]) for i in range(3)]

    fig.suptitle("图4-2  城市间产业结构相似性矩阵（2000/2010/2020）", y=0.98, fontsize=FONT_TITLE)

    im_for_cbar = None
    for k, (ax, (y, cities, sim)) in enumerate(zip(axes, mats)):
        ax.set_title(f"({chr(97 + k)}) {y}年")
        if len(cities) == 0:
            ax.text(0.5, 0.5, f"{y}年无可用数据", ha="center", va="center")
            ax.set_axis_off()
            continue

        im = ax.imshow(sim, vmin=vmin, vmax=vmax, cmap="viridis", aspect="equal")
        im_for_cbar = im
        ax.set_xticks(range(len(cities)))
        ax.set_yticks(range(len(cities)))
        ax.set_xticklabels(cities, rotation=90)
        ax.set_yticklabels(cities)
        ax.grid(False)

    if im_for_cbar is not None:
        # Ticks and label on the top side, away from the heat maps.
        cb = fig.colorbar(im_for_cbar, cax=cax, orientation="horizontal")
        cb.set_label("产业结构相似度（余弦相似）")
        cax.xaxis.set_ticks_position("top")
        cax.xaxis.set_label_position("top")
        cb.ax.tick_params(labelsize=FONT_TICK)
    else:
        cax.set_axis_off()

    _save_fig(fig, out_pdf, dpi=300)

# =========================
# 6) 图4-3：相似性门槛效应（4阶段×2核心）+ bootstrap p值 + 结果表
#     ✅ 增加：每个“阶段×核心”单独输出 1 张 PDF + 1 份CSV（逐次完成）
# =========================
def fig4_3_threshold_similarity(panel: pd.DataFrame):
    """Draw Fig. 4-3: growth vs. structural similarity to each core city, per stage.

    Grid layout: rows = STAGES, columns = CORE_CITIES (a 4x2 subplot grid).
    In each cell, the stage-start similarity to the core city
    (similarity_to_core) is both the regressor and the threshold variable. The
    stage's average annual log GDP growth (stage_panel) is the outcome, fitted
    with threshold_regression (N_BOOT replicates).

    Outputs in OUTPUT_DIR: the combined PDF, a results-table CSV and the
    pooled regression-data CSV. With SAVE_INTERMEDIATE, every estimated cell
    also gets its own PDF and a CSV of points plus the 200-point fitted line.
    """
    out_pdf  = os.path.join(OUTPUT_DIR, "图4-3_产业结构相似性门槛效应_4阶段x2核心.pdf")
    out_csvT = os.path.join(OUTPUT_DIR, "表4-门槛回归结果表_产业结构相似性_4阶段x2核心.csv")
    out_csvD = os.path.join(OUTPUT_DIR, "图4-3_回归面板数据.csv")

    x_label = "与核心城市的产业结构相似度（阶段起点）"
    y_label = "年均GDP增长率（log，%）"

    summary_rows = []
    data_parts = []

    fig, axes = plt.subplots(4, 2, figsize=(16, 18), constrained_layout=True)
    fig.suptitle("图4-3  产业结构相似性门槛效应（4阶段×2核心）", y=0.995, fontsize=FONT_TITLE)

    for i, (y0, y1) in enumerate(STAGES):
        stage = f"{y0}-{y1}"
        st = stage_panel(panel, y0, y1)

        for j, core in enumerate(CORE_CITIES):
            ax = axes[i, j]

            sim0 = similarity_to_core(panel, y0, core)
            if sim0 is None:
                # Core city missing in the stage's start year: leave a note instead.
                ax.text(0.5, 0.5, f"{y0}年缺少核心城市：{core}", ha="center", va="center")
                ax.set_axis_off()
                continue

            data = (
                st.merge(sim0.rename("sim0"), left_on="city", right_index=True, how="inner")
                .dropna(subset=["sim0", "growth"])
            )

            res = threshold_regression(
                y=data["growth"].values,
                x=data["sim0"].values,
                q=None,
                trim=0.15,
                n_boot=N_BOOT,
                seed=RANDOM_SEED + i * 10 + j,
            )

            _draw_threshold_panel(
                ax=ax, data=data, xcol="sim0", ycol="growth", res=res,
                xlabel=x_label, ylabel=y_label,
                title=f"({chr(97 + i)}) {stage}  核心：{core}",
            )

            if res is None:
                # Nothing estimated: no per-step files and no table/data rows.
                continue

            if SAVE_INTERMEDIATE and len(data) > 0:
                single = plt.figure(figsize=(7.2, 5.2), constrained_layout=True)
                _draw_threshold_panel(
                    ax=single.add_subplot(111), data=data, xcol="sim0", ycol="growth", res=res,
                    xlabel=x_label, ylabel=y_label,
                    title=f"图4-3（逐次） {y0}-{y1} 核心：{core}",
                )
                _save_fig(single, os.path.join(INTER_PDF_DIR, f"图4-3_{y0}-{y1}_核心{core}.pdf"), dpi=300)

                gamma = res["gamma"]
                sx = data["sim0"].values
                grid = np.linspace(np.nanmin(sx), np.nanmax(sx), 200)
                # Fit statistics repeated on every row of the per-cell CSV.
                fit_stats = {
                    "gamma": gamma,
                    "a1": res["a1"], "a2": res["a2"], "b1": res["b1"], "b2": res["b2"],
                    "R2_threshold": res["r2"], "R2_linear": res["r2_0"],
                    "LR": res["lr_stat"], "p_boot": res["p_boot"],
                    "N": res["n"],
                }
                points = pd.DataFrame({
                    "type": "point",
                    "stage": stage,
                    "core": core,
                    "city": data["city"].values,
                    "x_sim0": sx,
                    "y_growth": data["growth"].values,
                    "yhat": _predict_threshold(res, sx),
                    "regime": np.where(sx <= gamma, "low", "high"),
                    **fit_stats,
                })
                line = pd.DataFrame({
                    "type": "fit",
                    "stage": stage,
                    "core": core,
                    "city": np.nan,
                    "x_sim0": grid,
                    "y_growth": np.nan,
                    "yhat": _predict_threshold(res, grid),
                    "regime": np.where(grid <= gamma, "low", "high"),
                    **fit_stats,
                })
                _save_csv(
                    pd.concat([points, line], ignore_index=True),
                    os.path.join(INTER_CSV_DIR, f"图4-3_{y0}-{y1}_核心{core}_数据.csv"),
                    index=False,
                )

            summary_rows.append({
                "阶段": stage,
                "核心城市": core,
                "样本量N": res["n"],
                "门槛值γ": res["gamma"],
                "低区间截距a1": res["a1"],
                "高区间截距a2": res["a2"],
                "低区间系数b1": res["b1"],
                "高区间系数b2": res["b2"],
                "R2(门槛)": res["r2"],
                "R2(线性)": res["r2_0"],
                "LR统计量": res["lr_stat"],
                "bootstrap_p值": res["p_boot"],
            })

            part = data[["city", "growth", "sim0"]].copy()
            part["stage"] = stage
            part["core"] = core
            data_parts.append(part)

    if summary_rows:
        _save_csv(pd.DataFrame(summary_rows), out_csvT, index=False)
    if data_parts:
        _save_csv(pd.concat(data_parts, ignore_index=True), out_csvD, index=False)

    _save_fig(fig, out_pdf, dpi=300)

# =========================
# 7) 图4-4：产业结构升级（第三产业比重变化）门槛效应（4阶段）
#     ✅ 增加：每个阶段单独输出 1 张 PDF + 1 份CSV（逐次完成）
# =========================
def fig4_4_threshold_upgrade(panel: pd.DataFrame):
    """Render Figure 4-4: threshold effect of tertiary-share change on GDP growth, per stage.

    For every ``(start, end)`` pair in ``STAGES`` the stage sample is taken from
    ``stage_panel``, rows missing ``upgrade_p3`` or ``growth`` are dropped, and
    ``threshold_regression`` (growth on upgrade_p3, trim 0.15, ``N_BOOT``
    bootstrap draws, seed ``RANDOM_SEED + 100 + k``) is fitted. Each stage is drawn
    into one cell of a 2x2 grid via ``_draw_threshold_panel``.

    Files written:
      * ``OUTPUT_DIR``/图4-4_产业结构升级门槛效应_分阶段.pdf - the combined figure;
      * ``OUTPUT_DIR``/图4-4_门槛回归结果.csv - one row per stage with a fit;
      * ``OUTPUT_DIR``/图4-4_回归面板数据.csv - the stacked per-stage samples.
    With ``SAVE_INTERMEDIATE`` enabled, every fitted, non-empty stage additionally
    gets a standalone PDF (``INTER_PDF_DIR``) and a points + fitted-line CSV
    (``INTER_CSV_DIR``).

    A stage whose regression returns ``None`` is still drawn in the combined
    figure but contributes no table rows and no intermediate files.

    NOTE: the 2x2 layout assumes ``STAGES`` has at most four entries.
    """
    pdf_path = os.path.join(OUTPUT_DIR, "图4-4_产业结构升级门槛效应_分阶段.pdf")
    result_csv = os.path.join(OUTPUT_DIR, "图4-4_门槛回归结果.csv")
    sample_csv = os.path.join(OUTPUT_DIR, "图4-4_回归面板数据.csv")

    xlab = "第三产业比重变化（百分点）"
    ylab = "年均GDP增长率（log，%）"

    def _csv_frame(kind, stage, city, xv, yv, yhat, fit):
        """Build one part (scatter points or fitted line) of the per-stage CSV."""
        cut = fit["gamma"]
        return pd.DataFrame({
            "type": kind,
            "stage": stage,
            "city": city,
            "x_upgrade_p3": xv,
            "y_growth": yv,
            "yhat": yhat,
            # Observations at or below the threshold are labelled as the low regime.
            "regime": np.where(xv <= cut, "low", "high"),
            "gamma": cut,
            "a1": fit["a1"], "a2": fit["a2"], "b1": fit["b1"], "b2": fit["b2"],
            "R2_threshold": fit["r2"], "R2_linear": fit["r2_0"],
            "LR": fit["lr_stat"], "p_boot": fit["p_boot"],
            "N": fit["n"],
        })

    summary = []
    samples = []

    fig, axes = plt.subplots(2, 2, figsize=(16, 10), constrained_layout=True)
    fig.suptitle("图4-4  产业结构升级与GDP增长的门槛效应（分阶段）", y=0.99, fontsize=FONT_TITLE)

    for k, (start, end) in enumerate(STAGES):
        row, col = divmod(k, 2)
        stage = f"{start}-{end}"
        sample = stage_panel(panel, start, end).dropna(subset=["upgrade_p3", "growth"]).copy()

        fit = threshold_regression(
            y=sample["growth"].values,
            x=sample["upgrade_p3"].values,
            q=None,
            trim=0.15,
            n_boot=N_BOOT,
            seed=RANDOM_SEED + 100 + k,
        )

        _draw_threshold_panel(
            ax=axes[row, col],
            data=sample,
            xcol="upgrade_p3",
            ycol="growth",
            res=fit,
            xlabel=xlab,
            ylabel=ylab,
            title=f"({chr(97 + k)}) {stage}",
        )

        if fit is None:
            continue

        summary.append({
            "阶段": stage,
            "样本量N": fit["n"],
            "门槛值γ(百分点)": fit["gamma"],
            "低区间截距a1": fit["a1"],
            "高区间截距a2": fit["a2"],
            "低区间系数b1": fit["b1"],
            "高区间系数b2": fit["b2"],
            "R2(门槛)": fit["r2"],
            "R2(线性)": fit["r2_0"],
            "LR统计量": fit["lr_stat"],
            "bootstrap_p值": fit["p_boot"],
        })
        samples.append(sample[["city", "growth", "upgrade_p3"]].assign(stage=stage))

        # Per-stage ("逐次") standalone outputs.
        if not SAVE_INTERMEDIATE or sample.empty:
            continue

        solo = plt.figure(figsize=(7.2, 5.2), constrained_layout=True)
        _draw_threshold_panel(
            ax=solo.add_subplot(111),
            data=sample,
            xcol="upgrade_p3",
            ycol="growth",
            res=fit,
            xlabel=xlab,
            ylabel=ylab,
            title=f"图4-4（逐次） {stage}",
        )
        _save_fig(solo, os.path.join(INTER_PDF_DIR, f"图4-4_{stage}.pdf"), dpi=300)

        # CSV = observed points with fitted values, then a 200-point fitted line.
        xv = sample["upgrade_p3"].values
        points = _csv_frame(
            "point", stage, sample["city"].values, xv,
            sample["growth"].values, _predict_threshold(fit, xv), fit,
        )
        grid = np.linspace(np.nanmin(xv), np.nanmax(xv), 200)
        line = _csv_frame("fit", stage, np.nan, grid, np.nan, _predict_threshold(fit, grid), fit)
        _save_csv(
            pd.concat([points, line], ignore_index=True),
            os.path.join(INTER_CSV_DIR, f"图4-4_{stage}_数据.csv"),
            index=False,
        )

    if summary:
        _save_csv(pd.DataFrame(summary), result_csv, index=False)
    if samples:
        _save_csv(pd.concat(samples, ignore_index=True), sample_csv, index=False)

    _save_fig(fig, pdf_path, dpi=300)

# =========================
# 8) 图4-5：双核结构位置—增长/升级（分阶段气泡图，色标不遮挡）
#     ✅ 增加：每个阶段单独输出 1 张 PDF + 1 份CSV（逐次完成）
# =========================
def fig4_5_dualcore(panel: pd.DataFrame):
    """Render Figure 4-5: each city's position between the two cores, per stage.

    For every ``(y0, y1)`` in ``STAGES`` a city is placed at
    (similarity to 成都, similarity to 重庆), both taken from
    ``similarity_to_core(panel, y1, core)``, i.e. at the stage's end year. Bubble
    area is ``80 + 420 * |growth| / max|growth|`` (a constant 180 when every
    |growth| is zero or non-finite). Colour encodes ``upgrade_p3``. Dashed lines
    mark the stage medians of both similarities.

    All stage panels of the combined figure share ONE colour range, so the single
    colorbar is valid for every panel. Standalone per-stage figures keep their
    own colorbar and range.

    Files written:
      * ``OUTPUT_DIR``/图4-5_成渝双核结构位置与增长升级_分阶段.pdf - combined figure;
      * ``OUTPUT_DIR``/图4-5_双核结构指标.csv - stacked per-stage indicators;
      * with ``SAVE_INTERMEDIATE``: one PDF + one CSV per drawable stage.

    Stages without core-city similarity data, or with no cities left after the
    merge, are shown as an annotated empty panel and produce no rows.
    """
    out_pdf = os.path.join(OUTPUT_DIR, "图4-5_成渝双核结构位置与增长升级_分阶段.pdf")
    out_csv = os.path.join(OUTPUT_DIR, "图4-5_双核结构指标.csv")

    core_names = ("成都", "重庆")
    cbar_label = "第三产业比重变化（百分点）"

    all_rows = []
    scatters = []       # every stage scatter drawn in the combined figure
    colour_values = []  # their colour data, used to build one shared colour range

    fig, axes = plt.subplots(2, 2, figsize=(16, 12), constrained_layout=True)
    fig.suptitle("图4-5  成渝双核结构位置与城市长期增长 / 升级（分阶段）", y=0.99, fontsize=FONT_TITLE)

    for k, (y0, y1) in enumerate(STAGES):
        ax = axes[k // 2, k % 2]
        st = stage_panel(panel, y0, y1)

        sim_cd = similarity_to_core(panel, y1, "成都")  # end-of-stage position
        sim_cq = similarity_to_core(panel, y1, "重庆")
        if sim_cd is None or sim_cq is None:
            ax.text(0.5, 0.5, "缺少核心城市数据，无法绘制", ha="center", va="center")
            ax.set_axis_off()
            continue

        data = st.merge(sim_cd.rename("sim_cd"), left_on="city", right_index=True, how="inner") \
                 .merge(sim_cq.rename("sim_cq"), left_on="city", right_index=True, how="inner")

        # Guard: np.nanmax on an empty array raises ValueError.
        if data.empty:
            ax.text(0.5, 0.5, "该阶段无可匹配城市，无法绘制", ha="center", va="center")
            ax.set_axis_off()
            continue

        g_abs = np.abs(np.asarray(data["growth"].values, dtype=float))
        finite_g = g_abs[np.isfinite(g_abs)]
        g_max = finite_g.max() if finite_g.size else 0.0
        if g_max > 0:
            size = 80 + 420 * (g_abs / g_max)
        else:
            size = np.full(len(g_abs), 180.0)

        cval = data["upgrade_p3"].values

        sc = ax.scatter(data["sim_cd"].values, data["sim_cq"].values,
                        s=size, c=cval, cmap="coolwarm", alpha=0.85, edgecolor="k", linewidth=0.5)
        scatters.append(sc)
        colour_values.append(np.asarray(cval, dtype=float))

        ax.set_title(f"({chr(97+k)}) {y0}-{y1}")
        ax.set_xlabel("与成都产业结构相似度（期末）")
        ax.set_ylabel("与重庆产业结构相似度（期末）")

        vx = float(np.nanmedian(data["sim_cd"].values))
        vy = float(np.nanmedian(data["sim_cq"].values))
        ax.axvline(vx, linestyle="--", linewidth=1.2, alpha=0.8)
        ax.axhline(vy, linestyle="--", linewidth=1.2, alpha=0.8)

        # Label the core cities in bold; remember them so they are not labelled twice.
        labelled = set()
        present = set(data["city"])
        for name in core_names:
            if name in present:
                r = data.loc[data["city"] == name].iloc[0]
                ax.text(r["sim_cd"], r["sim_cq"], name, fontsize=FONT_TICK, weight="bold")
                labelled.add(name)

        # Also label the two cities with the largest |growth| (unless already labelled).
        top_idx = np.argsort(-g_abs)[:2]
        for idx in top_idx:
            r = data.iloc[idx]
            if r["city"] in labelled:
                continue
            ax.text(r["sim_cd"], r["sim_cq"], str(r["city"]), fontsize=FONT_TICK)

        tmp = data[["city", "sim_cd", "sim_cq", "growth", "upgrade_p3", "gdp0", "gdp1"]].copy()
        tmp["stage"] = f"{y0}-{y1}"
        tmp["bubble_size"] = size
        tmp["median_sim_cd"] = vx
        tmp["median_sim_cq"] = vy
        all_rows.append(tmp)

        # Per-stage ("逐次") standalone PDF + CSV.
        if SAVE_INTERMEDIATE and (len(tmp) > 0):
            fig1 = plt.figure(figsize=(7.2, 6.2), constrained_layout=True)
            ax1 = fig1.add_subplot(111)
            sc1 = ax1.scatter(tmp["sim_cd"].values, tmp["sim_cq"].values,
                              s=tmp["bubble_size"].values, c=tmp["upgrade_p3"].values,
                              cmap="coolwarm", alpha=0.85, edgecolor="k", linewidth=0.5)
            ax1.set_title(f"图4-5（逐次） {y0}-{y1}")
            ax1.set_xlabel("与成都产业结构相似度（期末）")
            ax1.set_ylabel("与重庆产业结构相似度（期末）")
            ax1.axvline(vx, linestyle="--", linewidth=1.2, alpha=0.8)
            ax1.axhline(vy, linestyle="--", linewidth=1.2, alpha=0.8)
            cb1 = fig1.colorbar(sc1, ax=ax1, shrink=0.85, pad=0.02)
            cb1.set_label(cbar_label)

            pdf_one = os.path.join(INTER_PDF_DIR, f"图4-5_{y0}-{y1}.pdf")
            _save_fig(fig1, pdf_one, dpi=300)

            csv_one = os.path.join(INTER_CSV_DIR, f"图4-5_{y0}-{y1}_数据.csv")
            _save_csv(tmp, csv_one, index=False)

    if scatters:
        # Each scatter autoscaled its own colour norm; put them all on one common
        # range so the single shared colorbar is correct for every panel.
        all_c = np.concatenate(colour_values)
        finite_c = all_c[np.isfinite(all_c)]
        if finite_c.size:
            vmin, vmax = float(finite_c.min()), float(finite_c.max())
            for s in scatters:
                s.set_clim(vmin, vmax)
        cb = fig.colorbar(scatters[-1], ax=axes.ravel().tolist(), shrink=0.90, pad=0.02)
        cb.set_label(cbar_label)

    if len(all_rows) > 0:
        _save_csv(pd.concat(all_rows, ignore_index=True), out_csv, index=False)
    _save_fig(fig, out_pdf, dpi=300)

# =========================
# 9) 主流程
# =========================
def main():
    """Run the whole figure pipeline end to end.

    Loads the panel from ``INPUT_XLSX`` via ``load_panel_any_excel`` and checks
    that every name in ``CORE_CITIES`` appears in the ``city`` column. It then
    prints a short summary (year span, city count, rows) and renders Figures
    4-1 to 4-5 in order.

    Raises:
        ValueError: if a core city is missing from the loaded panel (the first
            missing one, in ``CORE_CITIES`` order, is reported).
    """
    print(">> 读取数据:", INPUT_XLSX)
    panel = load_panel_any_excel(INPUT_XLSX)

    # Sanity check: both core cities must survive city-name normalisation.
    known_cities = set(panel["city"].unique())
    missing = [name for name in CORE_CITIES if name not in known_cities]
    if missing:
        raise ValueError(f"数据中缺少核心城市：{missing[0]}（请检查城市命名，如“成都市/重庆市”是否被识别）")

    years = panel["year"]
    first_year = int(years.min())
    last_year = int(years.max())
    print(f">> 年份范围：{first_year}-{last_year}，城市数：{panel['city'].nunique()}，样本量：{len(panel)}")

    pipeline = (
        (">> 输出 图4-1 ...", fig4_1),
        (">> 输出 图4-2 ...", fig4_2),
        (">> 输出 图4-3（门槛回归：相似性×2核心）...", fig4_3_threshold_similarity),
        (">> 输出 图4-4（门槛回归：升级）...", fig4_4_threshold_upgrade),
        (">> 输出 图4-5（双核位置）...", fig4_5_dualcore),
    )
    for banner, render in pipeline:
        print(banner)
        render(panel)

    if SAVE_INTERMEDIATE:
        print(">> 逐次计算的单图 PDF/CSV 已输出到：", INTERMEDIATE_DIR)

    print("\n✅ 全部完成：所有PDF与CSV已输出到：", OUTPUT_DIR)

# Run the full pipeline only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
